{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp models.XResNet1dPlus"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XResNet1dPlus\n",
    "\n",
    "> This is a modified version of fastai's XResNet model, available on GitHub. Changes include:\n",
    "* The API is modified to match timeseriesAI's default API.\n",
    "* (Optional) Uber's CoordConv 1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from tsai.imports import *\n",
    "from tsai.models.layers import *\n",
    "from tsai.models.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class XResNet1dPlus(nn.Sequential):\n",
    "    \"\"\"1d XResNet made of a `backbone` (three-conv stem, max pool, residual stages) and a `head`\n",
    "    (adaptive average pool, flatten, dropout, linear).\n",
    "\n",
    "    Arguments:\n",
    "        block:      callable that builds one residual block; invoked as `block(expansion, ni, nf, ...)`.\n",
    "        expansion:  value passed to every block; the head's linear layer takes\n",
    "                    `block_szs[-1]*expansion` input features.\n",
    "        layers:     number of blocks in each stage (one entry per stage).\n",
    "        fc_dropout: dropout probability applied right before the final linear layer.\n",
    "        c_in:       number of input channels fed to the first stem convolution.\n",
    "        n_out:      number of output features of the final linear layer.\n",
    "        stem_szs:   output channels of the three stem `ConvBlock`s.\n",
    "        widen:      factor applied to the base stage widths (64, 128, 256, 512, then 256 for extra stages).\n",
    "        sa:         when truthy, `sa=True` reaches only the last block of stage `len(layers)-4`.\n",
    "        act_cls:    activation class handed to the stem (`act=`) and to every block (`act_cls=`).\n",
    "        ks:         kernel size used by the stem, the max pool and the blocks; must be odd.\n",
    "        stride:     stride of the first stem conv, of the max pool and of the first block of every\n",
    "                    stage except the first.\n",
    "        coord:      forwarded to the stem `ConvBlock`s and to every block.\n",
    "        **kwargs:   forwarded unchanged to `block`.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if `ks` is even.\n",
    "    \"\"\"\n",
    "    @delegates(ResBlock1dPlus)\n",
    "    def __init__(self, block, expansion, layers, fc_dropout=0.0, c_in=3, n_out=1000, stem_szs=(32,32,64),\n",
    "                 widen=1.0, sa=False, act_cls=defaults.activation, ks=3, stride=2, coord=False, **kwargs):\n",
    "        store_attr('block,expansion,act_cls,ks')\n",
    "        if ks % 2 == 0: raise Exception('kernel size has to be odd!')\n",
    "        # Stem: three ConvBlocks chained c_in -> stem_szs[0] -> stem_szs[1] -> stem_szs[2]; only the first is strided\n",
    "        stem_szs = [c_in, *stem_szs]\n",
    "        stem = [ConvBlock(stem_szs[i], stem_szs[i+1], ks=ks, coord=coord, stride=stride if i==0 else 1,\n",
    "                          act=act_cls)\n",
    "                for i in range(3)]\n",
    "\n",
    "        # Stage widths: the four base widths, plus 256 for every stage beyond the fourth, all scaled by `widen`\n",
    "        block_szs = [int(o*widen) for o in [64,128,256,512] +[256]*(len(layers)-4)]\n",
    "        # Prepend the input width handed to the first stage\n",
    "        block_szs = [64//expansion] + block_szs\n",
    "        blocks    = self._make_blocks(layers, block_szs, sa, coord, stride, **kwargs)\n",
    "        backbone = nn.Sequential(*stem, MaxPool(ks=ks, stride=stride, padding=ks//2, ndim=1), *blocks)\n",
    "        head = nn.Sequential(AdaptiveAvgPool(sz=1, ndim=1), Flatten(), nn.Dropout(fc_dropout),\n",
    "                             nn.Linear(block_szs[-1]*expansion, n_out))\n",
    "        super().__init__(OrderedDict([('backbone', backbone), ('head', head)]))\n",
    "        self._init_cnn(self)\n",
    "\n",
    "    def _make_blocks(self, layers, block_szs, sa, coord, stride, **kwargs):\n",
    "        \"Build one stage per entry of `layers`; stage 0 uses stride 1, every later stage uses `stride`.\"\n",
    "        return [self._make_layer(ni=block_szs[i], nf=block_szs[i+1], blocks=l, coord=coord,\n",
    "                                 stride=1 if i==0 else stride, sa=sa and i==len(layers)-4, **kwargs)\n",
    "                for i,l in enumerate(layers)]\n",
    "\n",
    "    def _make_layer(self, ni, nf, blocks, coord, stride, sa, **kwargs):\n",
    "        \"Stack `blocks` residual blocks: only the first takes `ni` and `stride`, only the last may receive `sa`.\"\n",
    "        return nn.Sequential(\n",
    "            *[self.block(self.expansion, ni if i==0 else nf, nf, coord=coord, stride=stride if i==0 else 1,\n",
    "                      sa=sa and i==(blocks-1), act_cls=self.act_cls, ks=self.ks, **kwargs)\n",
    "              for i in range(blocks)])\n",
    "\n",
    "    def _init_cnn(self, m):\n",
    "        \"Recursively zero every bias and Kaiming-normal initialise every conv/linear weight found under `m`.\"\n",
    "        # The checks must look at `m`, the module being visited. Testing `self` (the whole model) never\n",
    "        # matched anything, which turned the entire initialisation pass into a no-op.\n",
    "        if getattr(m, 'bias', None) is not None: nn.init.constant_(m.bias, 0)\n",
    "        if isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d,nn.Linear)): nn.init.kaiming_normal_(m.weight)\n",
    "        for l in m.children(): self._init_cnn(l)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def _xresnetplus(expansion, layers, **kwargs):\n",
    "    \"Build an `XResNet1dPlus` out of `ResBlock1dPlus` blocks; every extra keyword goes to the model constructor.\"\n",
    "    model = XResNet1dPlus(ResBlock1dPlus, expansion, layers, **kwargs)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d18plus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 1 and [2, 2, 2, 2] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(1, [2, 2, 2, 2], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d34plus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 1 and [3, 4, 6, 3] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(1, [3, 4, 6, 3], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d50plus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 4 and [3, 4, 6, 3] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(4, [3, 4, 6, 3], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d101plus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 4 and [3, 4, 23, 3] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(4, [3, 4, 23, 3], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d152plus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 4 and [3, 8, 36, 3] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(4, [3, 8, 36, 3], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d18_deepplus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 1 and [2, 2, 2, 2, 1, 1] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(1, [2, 2, 2, 2, 1, 1], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d34_deepplus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 1 and [3, 4, 6, 3, 1, 1] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(1, [3, 4, 6, 3, 1, 1], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d50_deepplus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 4 and [3, 4, 6, 3, 1, 1] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(4, [3, 4, 6, 3, 1, 1], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d18_deeperplus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 1 and [2, 2, 1, 1, 1, 1, 1, 1] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(1, [2, 2, 1, 1, 1, 1, 1, 1], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d34_deeperplus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 1 and [3, 4, 6, 3, 1, 1, 1, 1] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(1, [3, 4, 6, 3, 1, 1, 1, 1], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)\n",
    "# Signature delegates to ResBlock1dPlus, the block class `_xresnetplus` builds the model from\n",
    "@delegates(ResBlock1dPlus)\n",
    "def xresnet1d50_deeperplus(c_in, c_out, act=nn.ReLU, **kwargs):\n",
    "    \"XResNet1dPlus with expansion 4 and [3, 4, 6, 3, 1, 1, 1, 1] blocks per stage; `c_out` is passed as `n_out`, `act` as `act_cls`.\"\n",
    "    return _xresnetplus(4, [3, 4, 6, 3, 1, 1, 1, 1], c_in=c_in, n_out=c_out, act_cls=act, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2051, -0.4082],\n",
       "        [ 0.1320, -0.2454],\n",
       "        [ 0.1718, -0.5900],\n",
       "        [ 0.3423, -0.1778],\n",
       "        [ 0.2053, -0.5386],\n",
       "        [ 0.2830, -0.5609],\n",
       "        [ 0.4794, -0.4538],\n",
       "        [ 0.1914,  0.1688],\n",
       "        [ 0.3116, -0.3382],\n",
       "        [ 0.0513, -0.3032],\n",
       "        [-0.0072, -0.4724],\n",
       "        [ 0.0463, -0.3346],\n",
       "        [ 0.4596, -0.7458],\n",
       "        [ 0.2832, -0.5932],\n",
       "        [ 0.2439, -0.0979],\n",
       "        [ 0.2160, -0.4570],\n",
       "        [ 0.2408, -0.2508],\n",
       "        [ 0.0452, -0.0187],\n",
       "        [ 0.4669, -0.1605],\n",
       "        [ 0.3136, -0.6333],\n",
       "        [-0.0930, -0.0047],\n",
       "        [-0.0522, -0.4743],\n",
       "        [-0.0063, -0.2872],\n",
       "        [ 0.5206, -0.2603],\n",
       "        [ 0.0385, -0.8225],\n",
       "        [ 0.1682, -0.7900],\n",
       "        [ 0.1275, -0.4700],\n",
       "        [ 0.0180, -0.4667],\n",
       "        [ 0.1500, -0.3306],\n",
       "        [ 0.4787, -0.5689],\n",
       "        [ 0.3416, -0.7539],\n",
       "        [ 0.0963, -0.5110]], grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: build the 18-layer variant with `coord=True` and run a random (32, 3, 50) batch through it\n",
    "net = xresnet1d18plus(3, 2, coord=True)\n",
    "x = torch.rand(32, 3, 50)\n",
    "net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 xresnet1d18plus\n",
      "1 xresnet1d34plus\n",
      "2 xresnet1d50plus\n",
      "3 xresnet1d18_deepplus\n",
      "4 xresnet1d34_deepplus\n",
      "5 xresnet1d50_deepplus\n",
      "6 xresnet1d18_deeperplus\n",
      "7 xresnet1d34_deeperplus\n",
      "8 xresnet1d50_deeperplus\n"
     ]
    }
   ],
   "source": [
    "bs, c_in, seq_len, c_out = 2, 4, 32, 2\n",
    "x = torch.rand(bs, c_in, seq_len)\n",
    "# Every architecture below must turn a (bs, c_in, seq_len) batch into (bs, c_out) outputs.\n",
    "# xresnet1d101plus and xresnet1d152plus are left out of the list (long test).\n",
    "archs = [\n",
    "    xresnet1d18plus,\n",
    "    xresnet1d34plus,\n",
    "    xresnet1d50plus,\n",
    "    xresnet1d18_deepplus,\n",
    "    xresnet1d34_deepplus,\n",
    "    xresnet1d50_deepplus,\n",
    "    xresnet1d18_deeperplus,\n",
    "    xresnet1d34_deeperplus,\n",
    "    xresnet1d50_deeperplus,\n",
    "]\n",
    "for i, arch in enumerate(archs):\n",
    "    print(i, arch.__name__)\n",
    "    test_eq(arch(c_in, c_out, sa=True, act=Mish, coord=True)(x).shape, (bs, c_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity checks on the layers selected by `is_bn` in the 34-layer variant: 38 layers are expected,\n",
    "# and the first item returned by `check_weight` must sum to 22\n",
    "m = xresnet1d34plus(4, 2, act=Mish)\n",
    "test_eq(len(get_layers(m, is_bn)), 38)\n",
    "test_eq(check_weight(m, is_bn)[0].sum(), 22)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hide\n",
    "# Calls the `create_scripts` helper and hands its result to `beep` (both come from the star imports;\n",
    "# presumably this is the notebook-to-module export step -- TODO confirm)\n",
    "out = create_scripts()\n",
    "beep(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
